FP8 Block quantize onnx export support #324
Conversation
Walkthrough
Adds optional FP8 blockwise quantization end-to-end: new FP8_SAGE_DEFAULT_CONFIG, a runtime-driven quantize_mha flag, propagation of per-tensor vs. dynamic block shapes through the diffusers attention and ONNX symbolic/export paths, TensorQuantizer and ScaledE4M3 support for block sizes, and ONNX helpers for blockwise FP8 quantize/dequantize.
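As a rough sketch of what the new config enables, the per-quantizer entry below is the one quoted in the config.py review further down; the surrounding quant_cfg/default/algorithm structure is an assumption modeled on other modelopt configs, not the verbatim contents of FP8_SAGE_DEFAULT_CONFIG.

```python
# Illustrative sketch only. The "*[qkv]_bmm_quantizer" entry is taken from the
# reviewed hunk in examples/diffusers/quantization/config.py; the wrapper keys
# ("quant_cfg", "default", "algorithm") are assumed from typical modelopt configs.
FP8_SAGE_DEFAULT_CONFIG = {
    "quant_cfg": {
        # Dynamic FP8 (E4M3) block quantization of the Q/K/V BMM inputs,
        # with blocks of 32 along the second-to-last (sequence) dimension.
        "*[qkv]_bmm_quantizer": {
            "type": "dynamic",
            "num_bits": (4, 3),
            "block_sizes": {-2: 32},
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}
```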
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User
participant QC as QuantConfig
participant Attn as _QuantAttention/FP8SDPA
participant Sym as FP8SDPA.symbolic
participant Export as export_fp8_mha
participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant
User->>QC: load FP8_SAGE_DEFAULT_CONFIG
User->>Attn: forward(query, key, value, ...)
Attn->>Attn: detect dynamic vs non-dynamic Q/K/V
Attn->>Attn: compute q/k/v block_shapes via _get_block_sizes_list()
Attn->>Sym: symbolic(..., q_block_shape, k_block_shape, v_block_shape)
Sym->>Export: export_fp8_mha(..., q_block_shape, k_block_shape, v_block_shape)
alt block shapes present
Export->>BlockHelpers: _fp8_block_quantize(Q/K/V, block_shape)
BlockHelpers-->>Export: quantized uint8 + scales
Export->>BlockHelpers: _fp8_block_dequantize(..., block_shape)
BlockHelpers-->>Export: dequantized tensors
else no block shapes
Export->>Export: per-tensor FP8 quantize/dequantize
end
Export-->>User: ONNX graph with FP8 (blockwise or per-tensor)
sequenceDiagram
autonumber
participant TensorQ as TensorQuantizer
participant SE4 as ScaledE4M3Function
participant Export as export_fp8
participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant
TensorQ->>TensorQ: _get_block_sizes_list(_original_input_shape)
TensorQ->>SE4: forward(x, scale, amax, block_sizes_list, ...)
alt ONNX export path
SE4->>Export: export_fp8(..., amax=None|float, block_sizes)
opt block_sizes provided
Export->>BlockHelpers: _fp8_block_quantize/_fp8_block_dequantize
end
else eager fake-quant
SE4->>SE4: apply fake-quant with block sizes
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Signed-off-by: Jingyu Xin <[email protected]>
e4d1775 to 071f167
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #324 +/- ##
==========================================
- Coverage 73.82% 73.71% -0.12%
==========================================
Files 172 172
Lines 17438 17471 +33
==========================================
+ Hits 12874 12879 +5
- Misses 4564 4592 +28
☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Nitpick comments (11)
examples/diffusers/quantization/config.py (1)
39-39
: Fix spacing inconsistency in configuration. The configuration has inconsistent spacing after commas: line 39 is missing a space after the comma between
(4, 3)
and
"block_sizes"
.
- "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}},
+ "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
modelopt/torch/quantization/export_onnx.py (2)
237-263
: Consider adding validation for block_sizes parameter. The new
_fp8_block_quantize
function should validate the structure and values of the
block_sizes
parameter to prevent runtime errors during ONNX export. Add validation at the beginning of the function:
def _fp8_block_quantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
    inputs: torch.Value,
    trt_high_precision_dtype: str,
    block_sizes: list,
):
    """Helper Function for Quantization."""
+   if not isinstance(block_sizes, list) or not block_sizes:
+       raise ValueError(f"block_sizes must be a non-empty list, got {block_sizes}")
+   if not all(isinstance(b, int) and b > 0 for b in block_sizes):
+       raise ValueError(f"All block sizes must be positive integers, got {block_sizes}")
+
    output_shape = sym_help._get_tensor_sizes(inputs)
534-535
: Fix typo in comment.
- # We cannot do block quant for the softmax's output
+ # We cannot do block quant for the softmax's output
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
653-653
: Remove trailing whitespace from blank lines. These blank lines contain unnecessary whitespace, which violates Python style guidelines.
-
+
    Args:
        shape: The tensor shape to use for conversion (can be tuple or torch.Size)
-
+
    Returns:
        List of block sizes for each dimension, or None if block_sizes is None
-
+
    Example:
Also applies to: 656-656, 659-659
961-962
: Consider thread-safety for the _original_input_shape attribute. Setting and deleting
_original_input_shape
as a temporary attribute could cause issues in multi-threaded scenarios where the same quantizer is used by multiple threads simultaneously. Consider using a context manager or local variable approach instead:
- setattr(self, "_original_input_shape", inputs.shape)
- inputs = self._process_for_blockquant(inputs)
+ original_shape = inputs.shape
+ inputs = self._process_for_blockquant(inputs)
+ # Pass original_shape to methods that need it
Alternatively, consider storing it in thread-local storage if multi-threading support is required.
modelopt/torch/quantization/plugins/diffusers.py (6)
117-124
: Fix assertion logic for mixed dynamic/non-dynamic quantizers. The current implementation requires all QKV quantizers to be either dynamic or non-dynamic together. However, the logic flow is problematic: if they're all non-dynamic, scales are computed, but if any is dynamic, it asserts that all must be dynamic. This creates a rigid constraint that may not be necessary for all use cases.
Consider refactoring to handle mixed cases more gracefully:
- if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic: - q_quantized_scale = self.q_bmm_quantizer._get_amax(query) - k_quantized_scale = self.k_bmm_quantizer._get_amax(key) - v_quantized_scale = self.v_bmm_quantizer._get_amax(value) - else: - assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type" - q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None + # Compute scales for non-dynamic quantizers, set None for dynamic ones + q_quantized_scale = None if self.q_bmm_quantizer._dynamic else self.q_bmm_quantizer._get_amax(query) + k_quantized_scale = None if self.k_bmm_quantizer._dynamic else self.k_bmm_quantizer._get_amax(key) + v_quantized_scale = None if self.v_bmm_quantizer._dynamic else self.v_bmm_quantizer._get_amax(value) + + # Optionally validate consistency if needed + dynamic_states = [self.q_bmm_quantizer._dynamic, self.k_bmm_quantizer._dynamic, self.v_bmm_quantizer._dynamic] + if len(set(dynamic_states)) > 1: + # Log warning or handle mixed dynamic states if necessary + pass
122-122
: Fix line length violation. Line 122 exceeds the 120-character limit (149 characters).
- assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type" + assert (self.q_bmm_quantizer._dynamic and + self.k_bmm_quantizer._dynamic and + self.v_bmm_quantizer._dynamic), "QKV QDQS must be in the same type"
144-146
: Remove trailing whitespace. Line 145 has trailing whitespace after the comma.
    q_block_sizes,
-   k_block_sizes, 
+   k_block_sizes,
    v_block_sizes,
231-233
: Inconsistent default values for scale parameters. The scale parameters have inconsistent default values in the symbolic method signature (
float | None = 1.0
) which doesn't match the forward method, where they default to
None
.
- q_quantized_scale: float | None = 1.0,
- k_quantized_scale: float | None = 1.0,
- v_quantized_scale: float | None = 1.0,
+ q_quantized_scale: float | None = None,
+ k_quantized_scale: float | None = None,
+ v_quantized_scale: float | None = None,
200-202
: Consider using TypeAlias for block shape type consistency. The block shape parameters use
list | None
type annotations repeatedly. Consider defining a type alias for better maintainability and consistency. Add at the top of the file after imports:
from typing import TypeAlias

BlockShape: TypeAlias = list[int] | None
Then update the signatures:
- q_block_shape: list | None = None,
- k_block_shape: list | None = None,
- v_block_shape: list | None = None,
+ q_block_shape: BlockShape = None,
+ k_block_shape: BlockShape = None,
+ v_block_shape: BlockShape = None,
Also applies to: 236-238
126-128
: Add validation for block sizes consistency. The code retrieves block sizes from quantizers but doesn't validate that they're compatible with the actual tensor shapes or the quantization configuration.
Consider adding validation to ensure block sizes are appropriate:
# Get block sizes lists for each quantizer if needed q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) + + # Validate block sizes if dynamic quantization is enabled + if self.q_bmm_quantizer._dynamic and q_block_sizes: + for dim, block_size in enumerate(q_block_sizes): + if block_size > 1 and query.shape[dim] % block_size != 0: + raise ValueError(f"Query dimension {dim} (size {query.shape[dim]}) is not divisible by block size {block_size}") + # Similar validation for k and v can be added
📜 Review details
📒 Files selected for processing (6)
examples/diffusers/quantization/config.py (1 hunks)
examples/diffusers/quantization/quantize.py (1 hunks)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
modelopt/torch/quantization/tensor_quant.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
_get_amax
(540-549)_get_block_sizes_list
(651-672)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
trt_high_precision_dtype
(407-409)trt_high_precision_dtype
(412-413)block_sizes
(289-291)block_sizes
(294-296)amax
(233-238)amax
(241-252)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (7)
amax
(233-238)amax
(241-252)block_sizes
(289-291)block_sizes
(294-296)trt_high_precision_dtype
(407-409)trt_high_precision_dtype
(412-413)forward
(902-1004)modelopt/torch/quantization/export_onnx.py (1)
export_fp8
(321-342)
🪛 GitHub Actions: Code Quality
modelopt/torch/quantization/plugins/diffusers.py
[error] 122-122: E501 Line too long (149 > 120).
[error] 134-136: mypy: Item 'str' of 'str | float | None' has no attribute 'shape' [union-attr] (lines 134-136).
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
[warning] 653-653: W293 Blank line contains whitespace.
[warning] 656-656: W293 Blank line contains whitespace.
[warning] 659-659: W293 Blank line contains whitespace.
[warning] 959-959: RUF003 Comment contains ambiguous RIGHT SINGLE QUOTATION MARK ’. Did you mean `` (GRAVE ACCENT)?
🔇 Additional comments (6)
examples/diffusers/quantization/quantize.py (1)
942-942
: LGTM! Runtime configuration for MHA quantization is properly handled. The change from
QuantizationConfig.quantize_mha
to
quant_config.quantize_mha
correctly uses the runtime configuration for MHA quantization. This aligns with the overall pattern of making MHA quantization configurable.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
651-673
: LGTM! Well-designed helper method for block size conversion. The
_get_block_sizes_list
method is well-implemented, with proper handling of both positive and negative dimension indices. The documentation with examples clearly explains its purpose.
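For readers who have not opened the diff, here is a minimal standalone sketch of the dict-to-list conversion this helper performs (the loop mirrors the hunk quoted in a later review round; the free-function form and the example shape are illustrative):

```python
def get_block_sizes_list(block_sizes: dict | None, shape) -> list[int] | None:
    """Illustrative restatement of TensorQuantizer._get_block_sizes_list.

    Dimensions may be keyed by positive or negative index; unspecified
    dimensions default to a block size of 1 (i.e., no blocking).
    """
    if block_sizes is None:
        return None
    result = []
    for dim in range(len(shape)):
        dim_negative = dim - len(shape)  # the same axis as a negative index
        size = block_sizes.get(dim) or block_sizes.get(dim_negative)
        result.append(size if size is not None else 1)
    return result

# Blocks of 32 along the sequence axis of a (B, heads, seq, head_dim) activation:
assert get_block_sizes_list({-2: 32}, (2, 24, 4096, 128)) == [1, 1, 32, 1]
```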
modelopt/torch/quantization/export_onnx.py (1)
298-298
: Incorrect — default block_sizes is not exercised by current callers. Call sites found in modelopt/torch/quantization/export_onnx.py pass block_sizes into _fp8_block_dequantize, so the hardcoded [1,1,128,1] default isn't used; no change required.
Likely an incorrect or invalid review comment.
modelopt/torch/quantization/tensor_quant.py (1)
415-415
: No action required — parse_args('i' → 'is') is backward-compatible. 'is' accepts both integer and integer-list, so existing callers that pass integers remain valid (repo tests call scaled_e4m3(..., 4, 3)).
modelopt/torch/quantization/plugins/diffusers.py (2)
87-148
: Good implementation of block quantization propagation! The changes effectively extend the FP8 quantization path to support dynamic block quantization:
- Proper handling of dynamic vs non-dynamic quantizers
- Clean propagation of block shape parameters through the call stack
- Maintains backward compatibility with existing code
221-221
: parse_args: 'is' is correct for int-list (or None). Matches existing usage in modelopt/torch/quantization/tensor_quant.py — the three "is" entries correctly map the three block-shape parameters to int[] | None.
Actionable comments posted: 3
♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
959-963
: Ensure _original_input_shape is always cleaned up (wrap in try/finally). State can leak if an exception occurs between setting and deleting _original_input_shape. Wrap the forward section after setup in try/finally.
- if ( + cleanup_original_shape = False + if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): # Reshape is required if the logic isnt handled in the simulation kernel self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_shape = True inputs = self._process_for_blockquant(inputs) - outputs = inputs + try: + outputs = inputs ... - if hasattr(self, "_original_input_shape"): - delattr(self, "_original_input_shape") + finally: + if cleanup_original_shape and hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape")Also applies to: 1002-1003
🧹 Nitpick comments (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
681-699
: Comment is misleading within dynamic-only branch. The comment says “Double scale Block quantization, including dynamic and static block quantization” but this branch executes only when type == "dynamic". Tighten the comment to avoid confusion.
- # Double scale Block quantization, including dynamic and static block quantization
+ # Dynamic double-scale block quantization path
modelopt/torch/quantization/plugins/diffusers.py (2)
117-132
: Align QKV mode and error message; avoid computing scales in export path
- Message “QKV QDQS must be in the same type” is unclear. Make it explicit: “Q, K, and V quantizers must all be dynamic or all be static.”
- Skip _get_amax when exporting; it’s unused at runtime and can be None for dynamic. Guard by torch.onnx.is_in_onnx_export().
- if ( + if ( not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic ): - q_quantized_scale = self.q_bmm_quantizer._get_amax(query) - k_quantized_scale = self.k_bmm_quantizer._get_amax(key) - v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + if not torch.onnx.is_in_onnx_export(): + q_quantized_scale = self.q_bmm_quantizer._get_amax(query) + k_quantized_scale = self.k_bmm_quantizer._get_amax(key) + v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + else: + q_quantized_scale = k_quantized_scale = v_quantized_scale = None else: assert ( self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic - ), "QKV QDQS must be in the same type" + ), "Q, K, and V quantizers must all be dynamic or all be static." q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
133-137
: Using a private helper across modules; consider promoting it to a public API. Calling _get_block_sizes_list from another module couples it to a private method. Expose it as a public helper (e.g., get_block_sizes_list or a util function) to avoid brittle dependencies.
Would you like me to extract a small utility, e.g., modelopt/torch/quantization/utils/block_sizes.py:get_block_sizes_list(shape, block_sizes)?
📜 Review details
📒 Files selected for processing (4)
examples/diffusers/quantization/config.py (1 hunks)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/diffusers/quantization/config.py
- modelopt/torch/quantization/export_onnx.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
_get_amax
(540-549)_get_block_sizes_list
(651-672)
🔇 Additional comments (2)
modelopt/torch/quantization/plugins/diffusers.py (2)
151-155
: Default disable flag may be surprising. You pass True when _disable_fp8_mha is absent, which disables FP8 MHA by default. Confirm this is intended for evaluation builds, or flip the default to False.
Would you like a config flag gate to avoid silent disablement in production?
251-269
: Verified — export_fp8_mha signature matches callsite. The definition in modelopt/torch/quantization/export_onnx.py (def export_fp8_mha at ~line 420) includes q_block_shape, k_block_shape, v_block_shape; the call in modelopt/torch/quantization/plugins/diffusers.py passes them — no mismatch found.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
703-716
: Fix scaled_e4m3 call-site argument order. scaled_e4m3 now expects block_sizes before bias; update all callers to the new signature.
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py:707 (new usage)
- modelopt/torch/quantization/calib/histogram.py:311
- tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py:55
- tests/gpu/torch/quantization/test_tensor_quant_cuda.py:148, 158, 166, 173, 185, 187, 202
Use: scaled_e4m3(inputs, amax, block_sizes, bias, E, M, ...). If no block_sizes, pass None as the third argument and move bias to the fourth.
♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
959-963
: Guarantee cleanup of _original_input_shape on exceptions. Deletion is unconditional but not exception-safe; wrap the quantization region in try/finally. Also fix the typo "isnt" → "isn't".
- if ( + cleanup_original_input_shape = False + if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): - # Reshape is required if the logic isnt handled in the simulation kernel + # Reshape is required if the logic isn't handled in the simulation kernel self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_input_shape = True inputs = self._process_for_blockquant(inputs) - outputs = inputs + try: + outputs = inputs @@ - if ( + if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant - ): - outputs = self._reset_to_original_shape(outputs) - - if hasattr(self, "_original_input_shape"): - delattr(self, "_original_input_shape") - return outputs + ): + outputs = self._reset_to_original_shape(outputs) + return outputs + finally: + if cleanup_original_input_shape and hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape")Also applies to: 1002-1003
🧹 Nitpick comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
651-673
: Type-safety and input validation for _get_block_sizes_list. Add explicit typing, validate keys, and guard length mismatches to avoid silently passing malformed shapes to downstream ONNX ops.
-def _get_block_sizes_list(self, shape): +from typing import Sequence + +def _get_block_sizes_list(self, shape: Sequence[int] | torch.Size) -> list[int] | None: @@ - block_sizes_list = [] - for dim in range(len(shape)): + # Only allow integer axes plus known metadata keys. + valid_meta = {"type", "scale_bits", "scale_block_sizes"} + assert all( + isinstance(k, int) or k in valid_meta for k in self.block_sizes.keys() + ), f"Invalid block_sizes keys: {list(self.block_sizes.keys())}" + + rank = len(shape) + block_sizes_list: list[int] = [] + for dim in range(rank): # Check both positive and negative dimension indices - dim_negative = dim - len(shape) + dim_negative = dim - rank block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_listmodelopt/torch/quantization/export_onnx.py (2)
238-265
: Validate block_shape rank and surface clearer errors. Add a rank check before emitting TRT_DynamicQuantize; mis-sized block_shapes currently fall through to TRT with cryptic errors.
def _fp8_block_quantize( @@ - input_type = inputs.type().scalarType() + input_type = inputs.type().scalarType() + rank = symbolic_helper._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) @@ quantized_output, scales_output = g.op( "trt::TRT_DynamicQuantize", inputs, block_shape_i=block_sizes,
503-509
: Block-shape consistency in FP8 MHA path. Validate that q/k/v block shapes match input ranks; also ensure the softmax path never receives a block shape.
- query_scaled = export_fp8( - g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape - ) + assert (q_block_shape is None) or ( + len(q_block_shape) == symbolic_helper._get_tensor_rank(query_scaled) + ), "q_block_shape rank mismatch." + query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape) @@ - key_transposed_scaled = export_fp8( - g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape - ) + assert (k_block_shape is None) or ( + len(k_block_shape) == symbolic_helper._get_tensor_rank(key_transposed_scaled) + ), "k_block_shape rank mismatch." + key_transposed_scaled = export_fp8(g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape) @@ - # We cannot do block quant for the softmax's output - attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None) + # We cannot do block quant for the softmax's output + attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None) @@ - value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape) + assert (v_block_shape is None) or ( + len(v_block_shape) == symbolic_helper._get_tensor_rank(value) + ), "v_block_shape rank mismatch." + value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)Also applies to: 535-549
modelopt/torch/quantization/plugins/diffusers.py (2)
117-132
: Q/K/V quantization mode must match: improve error and skip redundant work. The assertion is good. Minor: clarify the message and avoid computing per-tensor amax if any quantizer is dynamic.
- ), "QKV QDQS must be in the same type" + ), "Q/K/V quantization modes must match: either all dynamic or all static."
133-137
: Guard block size list creation when block_sizes is None. If a quantizer has no block_sizes, _get_block_sizes_list returns None; that's fine. Add a quick comment to make the intent explicit and future-proof.
- q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr] + # Returns None for per-tensor paths; ONNX export handles that by taking the non-block path. + q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr]
📜 Review details
📒 Files selected for processing (4)
examples/diffusers/quantization/config.py (1 hunks)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/diffusers/quantization/config.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
trt_high_precision_dtype
(407-409)trt_high_precision_dtype
(412-413)block_sizes
(289-291)block_sizes
(294-296)amax
(233-238)amax
(241-252)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
_get_amax
(540-549)_get_block_sizes_list
(651-672)
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/diffusers.py (1)
229-231
: Keep 't' for q/k/v scales — export extracts consts. export_onnx.py already extracts constant floats (uses sym_help._get_const / _maybe_get_const for scale/amax), so the current parse_args ("..., 't','t','t', ...") is fine; only change those three to 'f' if the export_fp8 const-extraction fix is removed. Location: modelopt/torch/quantization/plugins/diffusers.py (symbolic decorator around lines 229–231; the same check applies to lines ~241–249).
Signed-off-by: Jingyu Xin <[email protected]>
Please share a sample command for ONNX exporting for a supported model as well in the description.
):
# Tensor reshaping is required for static block quantization
# Tensor shapes are handled separately by the quantization kernels for dynamic block quantization
# Reshape is required if the logic isnt handled in the simulation kernel
nit: isn't
How do we check when the reshape is or isn't handled in the simulation kernel?
kernels are here: https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/modelopt/torch/quantization/src
The kernels only support the MX format and reshaped NVFP4. Other formats require using a Torch reshape. I think the previous comment made a false statement; I'll just add one more comment.
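To make the reshape fallback concrete, here is a generic, self-contained sketch of reshape-based static block fake-quantization along the last dimension; the function name and structure are illustrative, not the module's actual helper (448 is the E4M3 max-normal value; requires a PyTorch build with float8 support):

```python
import torch

def fake_quant_block_e4m3(x: torch.Tensor, block: int) -> torch.Tensor:
    """Illustrative reshape-based block fake-quant along the last dim.

    The last dimension must be divisible by `block`; each block is scaled so
    its amax maps to 448, rounded through float8_e4m3fn, then unscaled.
    """
    orig_shape = x.shape
    xb = x.unflatten(-1, (-1, block))                      # (..., n_blocks, block)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = 448.0 / amax
    xq = (xb * scale).to(torch.float8_e4m3fn).to(xb.dtype) / scale
    return xq.reshape(orig_shape)
```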
@staticmethod
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
@symbolic_helper.parse_args( | ||
"v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is" |
Why is this pattern changed?
You can check parse_args:
https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py#L301
Each value represents the input type, which helps Torch trace the graph more effectively.
Since the block shape is a list of integers, we add three "is".
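For context, a minimal sketch of how the decorator maps per-argument type codes onto a symbolic function ("v" keeps a graph value, "f"/"b" unwrap float/bool constants, "is" unwraps a list of ints); the signature and op arguments below are illustrative, not the actual FP8SDPA.symbolic:

```python
from torch.onnx import symbolic_helper

# "v": graph value (tensor), "f": float const, "b": bool const, "is": list of ints.
# Defining the function is enough to show the mapping; it is only invoked by the
# ONNX exporter, which supplies the GraphContext `g`.
@symbolic_helper.parse_args("v", "v", "f", "is")
def example_symbolic(g, inp, scale_tensor, dropout_p, block_shape):
    # By the time the body runs, dropout_p is a Python float and block_shape is
    # a plain list[int] (or None), so it can be attached as an op attribute.
    return g.op("trt::TRT_DynamicQuantize", inp, scale_tensor,
                block_shape_i=block_shape, outputs=2)
```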
Signed-off-by: Jingyu Xin <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
649-671
: Return None for “metadata‑only” block_sizes and add typing. Avoid forcing block mode with [1,…,1] when there are no per‑dim entries; this can route the ONNX path incorrectly.
Apply:
- def _get_block_sizes_list(self, shape): + def _get_block_sizes_list(self, shape) -> list[int] | None: """Convert block_sizes dict to list format based on tensor shape. @@ - if self.block_sizes is None: + if self.block_sizes is None: return None - - block_sizes_list = [] + # If there are no integer dimension entries with a meaningful block size, treat as no block quant. + has_dim_sizes = any( + isinstance(k, int) and (v is not None and v != 1) for k, v in self.block_sizes.items() + ) + if not has_dim_sizes: + return None + + block_sizes_list: list[int] = [] for dim in range(len(shape)): # Check both positive and negative dimension indices dim_negative = dim - len(shape) block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_list
1005-1007
: Centralize deletion of _original_input_shape in a finally block. Move this into a single finally at the end of forward so it runs regardless of success/failure.
Apply:
- if hasattr(self, "_original_input_shape"):
-     delattr(self, "_original_input_shape")
- return outputs
+ try:
+     return outputs
+ finally:
+     if hasattr(self, "_original_input_shape"):
+         delattr(self, "_original_input_shape")
modelopt/torch/quantization/export_onnx.py (3)
294-321
: Remove brittle default [1,1,128,1] and validate rank in _fp8_block_dequantize. Defaulting silently is dangerous and shape-dependent. Require explicit block_sizes and assert correctness. (Same concern raised earlier.)
Apply:
-def _fp8_block_dequantize( +def _fp8_block_dequantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, scales: torch.Value, trt_high_precision_dtype: str, otype: str | None = None, - block_sizes: list = [1, 1, 128, 1], + block_sizes: list, ): """Helper Function for Dequantization.""" output_shape = sym_help._get_tensor_sizes(inputs) + # Validate block shape + rank = sym_help._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) + assert all(isinstance(b, int) and b > 0 for b in block_sizes), ( + "All entries in block_shape must be positive integers." + ) + if otype is None: + otype = inputs.type().scalarType()
323-345
: Handle non-Python amax safely and validate block_shapes before the block Q/DQ path. float(amax) will break when amax is a graph Value/0‑dim tensor; also assert block_shapes align with input rank before calling block ops. (Echoing prior comment.)
Apply:
def export_fp8( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, - amax: float | None, + amax: float | None, trt_high_precision_dtype: str | None, block_sizes: list | None, ): """Export quantized model to FP8 ONNX.""" - scale = 1.0 if amax is None else 448.0 / float(amax) + if amax is None: + scale = 1.0 + else: + amax_const = sym_help._get_const(amax, "f", "amax") + # If not a constant at export time, fall back to neutral scale to avoid exporter errors. + scale = 1.0 if (amax_const is None or amax_const == 0) else 448.0 / float(amax_const) @@ - if not block_sizes: + if not block_sizes: q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype) return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype) else: + # Validate block shape early + rank = sym_help._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) + assert all(isinstance(b, int) and b > 0 for b in block_sizes), ( + "All entries in block_shape must be positive integers." + ) q_tensor, scales_output = _fp8_block_quantize( g, inputs, trt_high_precision_dtype, block_sizes ) return _fp8_block_dequantize( g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes )
238-265
: Validate block_shape against input rank and values in _fp8_block_quantize. Guard against mismatched ranks and non-positive entries to avoid invalid custom-op attributes at export time.
Apply:
def _fp8_block_quantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, trt_high_precision_dtype: str, block_sizes: list, ): """Helper Function for Quantization.""" output_shape = sym_help._get_tensor_sizes(inputs) + # Validate block shape + rank = sym_help._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) + assert all(isinstance(b, int) and b > 0 for b in block_sizes), ( + "All entries in block_shape must be positive integers." + )
🧹 Nitpick comments (2)
modelopt/torch/quantization/export_onnx.py (1)
512-518
: Pre-validate q/k block_shapes vs. tensor ranks to fail fast. Catch mismatches early instead of deep inside TRT ops.
Apply:
- query_scaled = export_fp8( + # Sanity-check block shapes + for name, t, bs in (("q", query_scaled, q_block_shape), ("k", key_transposed_scaled, k_block_shape)): + if bs is not None: + r = sym_help._get_tensor_rank(t) + assert r is not None and len(bs) == r, f"{name}_block_shape must match rank ({r})." + query_scaled = export_fp8( g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape ) @@ - key_transposed_scaled = export_fp8( + key_transposed_scaled = export_fp8( g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape )modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
962-965
: Guarantee cleanup of _original_input_shape via try/finally; fix typo. Ensure the attribute is deleted even on exceptions; also fix "isnt" -> "isn't".
Apply:
- # Reshape is required if the logic isnt handled in the simulation kernel - self._setup_for_blockquant(inputs) - setattr(self, "_original_input_shape", inputs.shape) - inputs = self._process_for_blockquant(inputs) + # Reshape is required if the logic isn't handled in the simulation kernel + cleanup_original_shape = False + try: + self._setup_for_blockquant(inputs) + setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_shape = True + inputs = self._process_for_blockquant(inputs) + except Exception: + # Make sure we don't leak transient attributes on failure + if cleanup_original_shape and hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape") + raise
📜 Review details
📒 Files selected for processing (2)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
trt_high_precision_dtype
(405-407)trt_high_precision_dtype
(410-411)block_sizes
(287-289)block_sizes
(292-294)amax
(231-236)amax
(239-250)
🔇 Additional comments (3)
modelopt/torch/quantization/export_onnx.py (2)
545-549
: Softmax stays non-block — good call. Keeping the softmax output per-tensor avoids mismatched block semantics across BMM2.
558-559
: Block‑quantizing V path looks correct. The cast order aligns with the earlier Q/K path.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
701-714
: Pass block_sizes to scaled_e4m3 by keyword to avoid positional drift. rg found only call sites (tests and histogram.py); I couldn't locate the scaled_e4m3 definition — confirm it accepts a block_sizes kwarg. If it does, apply:
outputs = scaled_e4m3( inputs, amax, self._get_bias(inputs), E, M, self._trt_high_precision_dtype, self._pass_through_bwd, - block_sizes_list, + block_sizes=block_sizes_list, )
Signed-off-by: Jingyu Xin <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
649-671
: Don’t force “block mode” when block_sizes has only metadata; return None. Current logic returns [1, ...] even if block_sizes has no per‑dim sizes, unintentionally enabling block handling downstream. Return None when there are no integer keys with v not in {None, 1}. Add typing while here.
- def _get_block_sizes_list(self, shape): + def _get_block_sizes_list(self, shape) -> list[int] | None: @@ - if self.block_sizes is None: - return None - - block_sizes_list = [] + if self.block_sizes is None: + return None + # If there are no per-dimension entries with a real block size, treat as no block quant. + has_dim_sizes = any( + isinstance(k, int) and (v is not None and v != 1) + for k, v in self.block_sizes.items() + ) + if not has_dim_sizes: + return None + + block_sizes_list: list[int] = [] for dim in range(len(shape)): # Check both positive and negative dimension indices dim_negative = dim - len(shape) block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_list
1006-1007
: Guarantee deletion of _original_input_shape with try/finally (prev feedback). Move deletion into a finally block guarded by the local flag so it always runs.
- if hasattr(self, "_original_input_shape"): - delattr(self, "_original_input_shape") - return outputs + try: + return outputs + finally: + if cleanup_original_input_shape and hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape")
🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
962-966
: Make _original_input_shape cleanup exception‑safe: set flag when set, clean in finally. Set a local flag when creating the attribute so the finally block can reliably clean it even if an exception occurs later.
if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): # Reshape is required if the logic is not handled in the simulation kernel # Only MX format and NVFP4 reshape are currently supported by the kernel. self._setup_for_blockquant(inputs) - setattr(self, "_original_input_shape", inputs.shape) + setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_input_shape = True inputs = self._process_for_blockquant(inputs)Add the flag near the top of forward (before this block):
- # Rotating the input + cleanup_original_input_shape = False + # Rotating the input
📜 Review details
📒 Files selected for processing (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
701-704
: LGTM: using the pre‑reshape shape to derive per‑dim blocks. Using _original_input_shape avoids mismapping after reshape/flatten.
705-714
: Pass block_sizes as a keyword; confirm the scaled_e4m3 signature and update callers. Change this call to pass block_sizes by name to avoid positional-argument drift; before merging, confirm the scaled_e4m3 definition accepts a named block_sizes parameter (or update all callers if the signature changed). Location: modelopt/torch/quantization/nn/modules/tensor_quantizer.py (around lines 705–714).
outputs = scaled_e4m3(
    inputs,
    amax,
    self._get_bias(inputs),
    E,
    M,
    self._trt_high_precision_dtype,
    self._pass_through_bwd,
-   block_sizes_list,
+   block_sizes=block_sizes_list,
)
Quick verification commands to run locally:
- rg -nP 'def\s+scaled_e4m3\s*(' -C2
- rg -nP '\bscaled_e4m3\s*(' -C2
- rg -nP '\bscaled_e4m3\s*([^)]block_sizes\s=' -n
Signed-off-by: Jingyu Xin <[email protected]>
What does this PR do?
Type of change: new feature
Overview:
Usage
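The Usage section is left empty in the PR description; as a hedged placeholder (not the author's actual command), quantizing with the new config would presumably follow the standard modelopt flow, with the import location assumed to be the example's config module:

```python
import modelopt.torch.quantization as mtq
# Import location assumed: the reviewed hunk adds the config in
# examples/diffusers/quantization/config.py.
from config import FP8_SAGE_DEFAULT_CONFIG

def quantize_backbone(backbone, run_calib_prompts):
    """Sketch: quantize a diffusion backbone with the new FP8 config."""
    def calib_loop(model):
        # Placeholder calibration: run a few representative prompts/steps.
        run_calib_prompts(model)
    return mtq.quantize(backbone, FP8_SAGE_DEFAULT_CONFIG, forward_loop=calib_loop)
```

ONNX export afterwards goes through the usual torch.onnx path, where the new symbolic helpers emit blockwise FP8 quantize/dequantize nodes for the Q/K/V BMM inputs (the real diffusers example drives this via examples/diffusers/quantization/quantize.py).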
Testing
evaluation
feature since the TRT kernel isn’t ready. No test cases are required at this time; we will add the test cases next month.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Improvements